{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MF7BncmmLBeO"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "from sklearn.datasets import load_digits\n",
    "from sklearn import datasets\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as tt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**DISCLAIMER**\n",
    "\n",
    "The presented code is not optimized; it serves an educational purpose. It is written for CPU, it uses only fully-connected networks and an extremely simplistic dataset. However, it contains all the components needed to understand how an example of a score-based generative model works, and it should be rather easy to extend it to more sophisticated models. This code can be run on almost any laptop/PC, and it takes a couple of minutes at most to get the result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "RKsmjLumL5A2"
   },
   "source": [
    "## Dataset: Digits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, we go wild and use a dataset that is simpler than MNIST! We use a scikit-learn dataset called Digits. It consists of ~1800 images of size 8x8, and each pixel can take values in $\\{0, 1, \\ldots, 16\\}$.\n",
    "\n",
    "The goal of using this dataset is that everyone can run it on a laptop, without any gpu etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hSWUnXAYLLif"
   },
   "outputs": [],
   "source": [
    "class Digits(Dataset):\n",
    "    \"\"\"Scikit-Learn Digits dataset.\"\"\"\n",
    "\n",
    "    def __init__(self, mode='train', transforms=None):\n",
    "        digits = load_digits()\n",
    "        if mode == 'train':\n",
    "            self.data = digits.data[:1000].astype(np.float32)\n",
    "        elif mode == 'val':\n",
    "            self.data = digits.data[1000:1350].astype(np.float32)\n",
    "        else:\n",
    "            self.data = digits.data[1350:].astype(np.float32)\n",
    "\n",
    "        self.transforms = transforms\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = self.data[idx]\n",
    "        if self.transforms:\n",
    "            sample = self.transforms(sample)\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qSP2qiMqMICK"
   },
   "source": [
    "## Score-based Generative Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SBGM(nn.Module):\n",
    "    def __init__(self, snet, sigma, D, T):\n",
    "        super(SBGM, self).__init__()\n",
    "\n",
    "        print(\"SBGM by JT.\")\n",
    "        \n",
    "        # sigma parameter\n",
    "        self.sigma = torch.Tensor([sigma])\n",
    "        \n",
    "        # define the base distribution (multivariate Gaussian with the diagonal covariance)\n",
    "        var = (1./(2.* torch.log(self.sigma))) * (self.sigma**2 - 1.)\n",
    "        self.base = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(D), var * torch.eye(D))\n",
    "        \n",
    "        # score model\n",
    "        self.snet = snet\n",
    "        \n",
    "        # time embedding (a single linear layer)\n",
    "        self.time_embedding = nn.Sequential(nn.Linear(1, D), nn.Tanh())\n",
    "        \n",
    "        # other hyperparams\n",
    "        self.D = D\n",
    "        \n",
    "        self.T = T\n",
    "                \n",
    "        self.EPS = 1.e-5\n",
    "        \n",
    "    def sigma_fun(self, t):\n",
    "        # the sigma function (dependent on t), it is the std of the distribution\n",
    "        return torch.sqrt((1./(2. * torch.log(self.sigma))) * (self.sigma**(2.*t) - 1.))\n",
    "\n",
    "    def log_p_base(self, x):\n",
    "        # the log-probability of the base distribition, p_1(x)\n",
    "        log_p = self.base.log_prob(x)\n",
    "        return log_p\n",
    "    \n",
    "    def sample_base(self, x_0):\n",
    "        # sampling from the base distribution\n",
    "        return self.base.rsample(sample_shape=torch.Size([x_0.shape[0]]))\n",
    "        \n",
    "    def sample_p_t(self, x_0, x_1, t):\n",
    "        # sampling from p_0t(x_t|x_0)\n",
    "        # x_0 ~ data, x_1 ~ noise\n",
    "        x = x_0 + self.sigma_fun(t) * x_1\n",
    "        \n",
    "        return x\n",
    "    \n",
    "    def lambda_t(self, t):\n",
    "        # the loss weighting\n",
    "        return self.sigma_fun(t)**2\n",
    "    \n",
    "    def diffusion_coeff(self, t):\n",
    "        # the diffusion coefficient in the SDE\n",
    "        return self.sigma**t\n",
    "    \n",
    "    def forward(self, x_0, reduction='mean'):\n",
    "        # =====\n",
    "        # x_1 ~ the base distribiution\n",
    "        x_1 = torch.randn_like(x_0)\n",
    "        # t ~ Uniform(0, 1)\n",
    "        t = torch.rand(size=(x_0.shape[0], 1))  * (1. - self.EPS) + self.EPS \n",
    "        \n",
    "        # =====\n",
    "        # sample from p_0t(x|x_0)\n",
    "        x_t = self.sample_p_t(x_0, x_1, t)\n",
    "\n",
    "        # =====\n",
    "        # invert noise\n",
    "        # NOTE: here we use the correspondence eps_theta(x,t) = -sigma*t score_theta(x,t)\n",
    "        t_embd = self.time_embedding(t)\n",
    "        x_pred = -self.sigma_fun(t) * self.snet(x_t + t_embd)\n",
    "\n",
    "        # =====LOSS: Score Matching\n",
    "        # NOTE: since x_pred is the predicted noise, and x_1 is noise, this corresponds to Noise Matching \n",
    "        #       (i.e., the loss used in diffusion-based models by Ho et al.)\n",
    "        SM_loss = 0.5 * self.lambda_t(t) * torch.pow(x_pred + x_1, 2).mean(-1)\n",
    "        \n",
    "        if reduction == 'sum':\n",
    "            loss = SM_loss.sum()\n",
    "        else:\n",
    "            loss = SM_loss.mean()\n",
    "\n",
    "        return loss\n",
    "\n",
    "    def sample(self, batch_size=64):\n",
    "        # 1) sample x_0 ~ Normal(0,1/(2log sigma) * (sigma**2 - 1))\n",
    "        x_t = self.sample_base(torch.empty(batch_size, self.D))\n",
    "        \n",
    "        # Apply Euler's method\n",
    "        # NOTE: x_0 - data, x_1 - noise\n",
    "        #       Therefore, we must use BACKWARD Euler's method! This results in the minus sign! \n",
    "        ts = torch.linspace(1., self.EPS, self.T)\n",
    "        delta_t = ts[0] - ts[1]\n",
    "        \n",
    "        for t in ts[1:]:\n",
    "            tt = torch.Tensor([t])\n",
    "            u = 0.5 * self.diffusion_coeff(tt) * self.snet(x_t + self.time_embedding(tt))\n",
    "            x_t = x_t - delta_t * u\n",
    "        \n",
    "        x_t = torch.tanh(x_t)\n",
    "        return x_t\n",
    "    \n",
    "    def log_prob_proxy(self, x_0, reduction=\"mean\"):\n",
    "        # Calculate the proxy of the log-likelihood (see (Song et al., 2021))\n",
    "        # NOTE: Here, we use a single sample per time step (this is done only for simplicity and speed);\n",
    "        # To get a better estimate, we should sample more noise\n",
    "        ts = torch.linspace(self.EPS, 1., self.T)\n",
    "\n",
    "        for t in ts:\n",
    "            # Sample noise\n",
    "            x_1 = torch.randn_like(x_0)\n",
    "            # Sample from p_0t(x_t|x_0)\n",
    "            x_t = self.sample_p_t(x_0, x_1, t)\n",
    "            # Predict noise\n",
    "            t_embd = self.time_embedding(torch.Tensor([t]))\n",
    "            x_pred = -self.snet(x_t + t_embd) * self.sigma_fun(t)\n",
    "            # loss (proxy)          \n",
    "            if t == self.EPS:\n",
    "                proxy = 0.5 * self.lambda_t(t) * torch.pow(x_pred + x_1, 2).mean(-1)\n",
    "            else:\n",
    "                proxy = proxy + 0.5 * self.lambda_t(t) * torch.pow(x_pred + x_1, 2).mean(-1)\n",
    "            \n",
    "        if reduction == \"mean\":\n",
    "            return proxy.mean()\n",
    "        elif reduction == \"sum\":\n",
    "            return proxy.sum()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vUoPkTmrMVnx"
   },
   "source": [
    "## Evaluation and Training functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JvwmRoi7MVto"
   },
   "source": [
    "**Evaluation step, sampling and curve plotting**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JHx4RIqDLZe9"
   },
   "outputs": [],
   "source": [
    "def evaluation(test_loader, name=None, model_best=None, epoch=None):\n",
    "    # EVALUATION\n",
    "    if model_best is None:\n",
    "        # load best performing model\n",
    "        model_best = torch.load(name + '.model')\n",
    "\n",
    "    model_best.eval()\n",
    "    loss = 0.\n",
    "    N = 0.\n",
    "    for indx_batch, test_batch in enumerate(test_loader):\n",
    "        loss_t = model_best.log_prob_proxy(test_batch, reduction='sum')\n",
    "        loss = loss + loss_t.item()\n",
    "        N = N + test_batch.shape[0]\n",
    "    loss = loss / N\n",
    "\n",
    "    if epoch is None:\n",
    "        print(f'FINAL LOSS: nll={loss}')\n",
    "    else:\n",
    "        print(f'Epoch: {epoch}, val nll={loss}')\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "def samples_real(name, test_loader):\n",
    "    \"\"\"Save a 4x4 grid with real images taken from the first batch of `test_loader`.\n",
    "\n",
    "    Args:\n",
    "        name: path prefix; the figure is written to `name + '_real_images.pdf'`.\n",
    "        test_loader: an iterable of batches; each image must hold 64 values (8x8).\n",
    "    \"\"\"\n",
    "    # REAL-------\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    x = next(iter(test_loader)).detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        # the batch may hold fewer than num_x * num_y images; the extra panels stay empty\n",
    "        if i < len(x):\n",
    "            plottable_image = np.reshape(x[i], (8, 8))\n",
    "            ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    fig.savefig(name+'_real_images.pdf', bbox_inches='tight')\n",
    "    # close this very figure, so figures do not pile up over repeated calls\n",
    "    plt.close(fig)\n",
    "\n",
    "\n",
    "def samples_generated(name, data_loader, extra_name='', T=None):\n",
    "    \"\"\"Save a 4x4 grid with images sampled from the model stored in `name + '.model'`.\n",
    "\n",
    "    Args:\n",
    "        name: path prefix, used for loading the model and for the output file,\n",
    "            `name + '_generated_images' + extra_name + '.pdf'`.\n",
    "        data_loader: not used; kept so that existing calls keep working.\n",
    "        extra_name: a suffix that is inserted into the output file name.\n",
    "        T: when not None, it overrides the number of sampling steps of the model.\n",
    "    \"\"\"\n",
    "    # GENERATIONS-------\n",
    "    # NOTE: the file holds a whole pickled module, so only load files you trust.\n",
    "    # Recent PyTorch releases default to weights_only=True, which refuses such files.\n",
    "    try:\n",
    "        model_best = torch.load(name + '.model', weights_only=False)\n",
    "    except TypeError:\n",
    "        # PyTorch versions that predate the `weights_only` argument\n",
    "        model_best = torch.load(name + '.model')\n",
    "    model_best.eval()\n",
    "\n",
    "    if T is not None:\n",
    "        model_best.T = T\n",
    "\n",
    "    num_x = 4\n",
    "    num_y = 4\n",
    "    # sampling is inference only, so there is no need to build an autograd graph\n",
    "    with torch.no_grad():\n",
    "        x = model_best.sample(batch_size=num_x * num_y)\n",
    "    x = x.detach().numpy()\n",
    "\n",
    "    fig, axes = plt.subplots(num_x, num_y)\n",
    "    for i, ax in enumerate(axes.flatten()):\n",
    "        plottable_image = np.reshape(x[i], (8, 8))\n",
    "        ax.imshow(plottable_image, cmap='gray')\n",
    "        ax.axis('off')\n",
    "\n",
    "    fig.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')\n",
    "    # close this very figure; this function is called many times during training\n",
    "    plt.close(fig)\n",
    "\n",
    "def plot_curve(name, nll_val):\n",
    "    \"\"\"Plot the validation proxy per epoch and save it to `name + '_nll_val_curve.pdf'`.\"\"\"\n",
    "    epochs = np.arange(len(nll_val))\n",
    "    plt.plot(epochs, nll_val, linewidth='3')\n",
    "    plt.xlabel('epochs')\n",
    "    plt.ylabel('proxy')\n",
    "\n",
    "    out_path = name + '_nll_val_curve.pdf'\n",
    "    plt.savefig(out_path, bbox_inches='tight')\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umU3VYKzMbDt"
   },
   "source": [
    "**Training step**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "NxkUZ1xVLbm_"
   },
   "outputs": [],
   "source": [
    "def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):\n",
    "    nll_val = []\n",
    "    best_nll = 1000.\n",
    "    patience = 0\n",
    "\n",
    "    # Main loop\n",
    "    for e in range(num_epochs):\n",
    "        # TRAINING\n",
    "        model.train()\n",
    "        for indx_batch, batch in enumerate(training_loader):\n",
    "            loss = model.forward(batch)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward(retain_graph=True)\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation\n",
    "        loss_val = evaluation(val_loader, model_best=model, epoch=e)\n",
    "        nll_val.append(loss_val)  # save for plotting\n",
    "\n",
    "        if e == 0:\n",
    "            print('saved!')\n",
    "            torch.save(model, name + '.model')\n",
    "            best_nll = loss_val\n",
    "        else:\n",
    "            if loss_val < best_nll:\n",
    "                print('saved!')\n",
    "                torch.save(model, name + '.model')\n",
    "                best_nll = loss_val\n",
    "                patience = 0\n",
    "\n",
    "                samples_generated(name, val_loader, extra_name=\"_epoch_\" + str(e))\n",
    "            else:\n",
    "                patience = patience + 1\n",
    "        \n",
    "        if patience > max_patience:\n",
    "            break\n",
    "\n",
    "    nll_val = np.asarray(nll_val)\n",
    "\n",
    "    return nll_val"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0BXJ9dN0MinB"
   },
   "source": [
    "## Experiments"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KsF7f-Q-MkWu"
   },
   "source": [
    "**Initialize datasets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rescale_pixels(x):\n",
    "    \"\"\"Affinely rescale raw pixel values: 0 is mapped to -1 and 17 would be mapped to 1.\"\"\"\n",
    "    return 2. * (x / 17.) - 1.\n",
    "\n",
    "\n",
    "# pixel values in {0, ..., 16} end up in [-1, 15/17], i.e., roughly in [-1, 1]\n",
    "transforms = tt.Lambda(rescale_pixels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fqZKMNM0LdQ1"
   },
   "outputs": [],
   "source": [
    "# the three splits share the same preprocessing\n",
    "train_data, val_data, test_data = [Digits(mode=split, transforms=transforms)\n",
    "                                   for split in ('train', 'val', 'test')]\n",
    "\n",
    "# only the training batches are reshuffled; validation and test keep a fixed order\n",
    "training_loader = DataLoader(train_data, batch_size=32, shuffle=True)\n",
    "val_loader, test_loader = [DataLoader(split_data, batch_size=32, shuffle=False)\n",
    "                           for split_data in (val_data, test_data)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6lEKUznpMns7"
   },
   "source": [
    "**Hyperparameters**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ANQo7LrGLjIN"
   },
   "outputs": [],
   "source": [
    "prob_path = \"sbgm\"\n",
    "\n",
    "D = 64   # input dimension\n",
    "\n",
    "M = 512  # the number of neurons in scale (s) and translation (t) nets\n",
    "\n",
    "T = 20\n",
    "\n",
    "sigma = 1.01\n",
    "\n",
    "lr = 1e-3 # learning rate\n",
    "num_epochs = 1000 # max. number of epochs\n",
    "max_patience = 50 # an early stopping is used, if training doesn't improve for longer than 20 epochs, it is stopped"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-7APXeunMrDh"
   },
   "source": [
    "**Creating a folder for results**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bjSUn1eWLkWm"
   },
   "outputs": [],
   "source": [
    "name = prob_path + '_' + str(T)\n",
    "result_dir = 'results/' + name + '/'\n",
    "# makedirs also creates the parent 'results' folder (os.mkdir fails on a fresh checkout)\n",
    "# and exist_ok makes the cell safe to re-run\n",
    "os.makedirs(result_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Hpwm6LWUMulQ"
   },
   "source": [
    "**Initializing the model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "FrnNsCqQLmK3",
    "outputId": "5f0cf2b1-0a96-4f5c-da9e-f78f909a5259"
   },
   "outputs": [],
   "source": [
    "# The score network: three hidden SiLU layers of width M and an output clipped to [-3, 3]\n",
    "hidden_widths = [D, M, M, M]\n",
    "layers = []\n",
    "for n_in, n_out in zip(hidden_widths[:-1], hidden_widths[1:]):\n",
    "    layers += [nn.Linear(n_in, n_out), nn.SiLU()]\n",
    "layers += [nn.Linear(M, D), nn.Hardtanh(min_val=-3., max_val=3.)]\n",
    "nnet = nn.Sequential(*layers)\n",
    "\n",
    "# Eventually, we initialize the full model\n",
    "model = SBGM(nnet, sigma=sigma, D=D, T=T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3SzTemY3NSxO"
   },
   "source": [
    "**Optimizer - here we use Adamax**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "R9TZtLVtLoWc"
   },
   "outputs": [],
   "source": [
    "# OPTIMIZER\n",
    "# only the parameters that take part in training are handed to the optimizer\n",
    "trainable_params = [p for p in model.parameters() if p.requires_grad]\n",
    "optimizer = torch.optim.Adamax(trainable_params, lr=lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dNf__W_ONVHA"
   },
   "source": [
    "**Training loop**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "KhqHgluGLqIC",
    "outputId": "c52fa1e4-3376-4bff-9f87-6f03613c4e42"
   },
   "outputs": [],
   "source": [
    "# Training procedure\n",
    "training_kwargs = dict(\n",
    "    name=result_dir + name,\n",
    "    max_patience=max_patience,\n",
    "    num_epochs=num_epochs,\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    training_loader=training_loader,\n",
    "    val_loader=val_loader,\n",
    ")\n",
    "nll_val = training(**training_kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-3XTxgEcNXfp"
   },
   "source": [
    "**The final evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "okK1mV_-LrRU",
    "outputId": "4664693f-742d-4453-94cf-d051d2efa9be"
   },
   "outputs": [],
   "source": [
    "test_loss = evaluation(name=result_dir + name, test_loader=test_loader)\n",
    "# the context manager closes the file even if writing fails\n",
    "with open(result_dir + name + '_test_loss.txt', \"w\") as f:\n",
    "    f.write(str(test_loss))\n",
    "\n",
    "samples_real(result_dir + name, test_loader)\n",
    "samples_generated(result_dir + name, test_loader, extra_name='FINAL')\n",
    "\n",
    "plot_curve(result_dir + name, nll_val)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "vae_priors.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
